Skip to content

backport: feature-detect DSA experimental-attention spec for GLM-5/5.1/5.2#13

Open
yushengsu-thu wants to merge 5 commits into
bridgefrom
bridge-dev-glm
Open

backport: feature-detect DSA experimental-attention spec for GLM-5/5.1/5.2#13
yushengsu-thu wants to merge 5 commits into
bridgefrom
bridge-dev-glm

Conversation

@yushengsu-thu

@yushengsu-thu yushengsu-thu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

What

Make the radixark/Megatron-Bridge GLM-5 / 5.1 / 5.2 (glm_moe_dsa) bridge fully build on the radixark/miles image's megatron-core (0.16.0rc0), for both LoRA and full-FT paths, with no megatron-core source change. Covers GLM-5.1 (DSA) and GLM-5.2 (DSA + cross-layer index sharing + rope 8e6).

Files: src/megatron/bridge/models/glm_moe_dsa/cross_layer_dsa.py (new), src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py.

Commits

e4cd5f37 — GLM-5.1: feature-detect the DSA experimental-attention spec

glm5_bridge.py sets experimental_attention_variant="dsa". On older megatron-core the dispatcher (get_experimental_attention_variant_module_spec) only wires "gated_delta_net" and raises ValueError for "dsa", and get_dsa_module_spec_for_backend omits the metainfo the variant layer-builder reads → the model can't build. transformer_layer_spec now points at _build_glm5_dsa_block_spec, whose wrapped dispatcher prefers megatron-core's native handling and only on the "dsa" ValueError (old core) back-fills via the shipped builder + sets metainfo["fuse_input_layernorm"]=False (MLA-based DSA keeps a separate, non-fused input layernorm, like the deepseek_v4 dsv4 spec). ⇒ transparent no-op on newer megatron-core; deletable after a core bump. (Previously a caller-side monkey-patch in miles bridge_lora_helpers.py, now consolidated here.)

2bdfa05d — GLM-5.2: parse the MLA rope dims (rotary_base + qk_pos_emb_head_dim)

GLM-5.2 sets rope_theta=8e6, and under transformers>=5.12 the parsed GlmMoeDsaConfig reports qk_rope_head_dim as head_dim (192) instead of the config.json value (64). The base config-mapping then sizes MLA linear_kv_down_proj as kv_lora_rank + 192 = 704, contradicting the checkpoint (kv_a_proj_with_mqa = kv_lora_rank + qk_rope_head_dim = 576 = 512 + 64).

  • rotary_base: read rope_theta whether nested in rope_parameters or flat.
  • qk_pos_emb_head_dim: re-read qk_rope_head_dim straight from config.json so the MLA rope/kv dims match the weights. No-op when the parse is already correct; GLM-5.1 unaffected (its head_dim already equals qk_rope_head_dim = 64).

74dade06 — GLM-5.2: DSA cross-layer index sharing (CrossLayerDSAttention)

GLM-5.2 keeps GLM-5.1's glm_moe_dsa arch, but only "computing" layers carry the lightning indexer and compute the sparse top-k; "skip" layers reuse the most recent computing layer's top-k (HF config index_topk_freq=4, index_skip_topk_offset=3 → computing Megatron-layers 1,2,3,7,11,…). megatron-core's DSA is per-layer only, so this adds a Bridge-owned CrossLayerDSAttention(DSAttention):

  • anchor layers compute + publish topk_indices to a per-microbatch holder (packed_seq_params for thd, thread-local for bshd);
  • skip layers del self.indexer (so the param set matches the subset checkpoint, which only stores indexer weights on computing layers) and reuse the source anchor's top-k via unfused_dsa_fn.

get_glm5_crosslayer_dsa_spec calls megatron-core's exact get_dsa_module_spec_for_backend and swaps only submodules.core_attention.module to the subclass (so the MLA structure — fused qk-layernorm, indexer submodules — is inherited verbatim). Feature-gated in _build_glm5_dsa_block_spec on dsa_index_topk_freq > 1, so GLM-5.1 (no freq → 1) keeps the existing per-layer path byte-for-byte.

Verified (via radixark/miles, 4×H200, bridge mode, TP4/EP4, bshd, mbs1)

  • GLM-5.2 7-layer train-only (jybsuper/GLM-5.2-7layer): build CrossLayerDSAttention → load the 98 GB subset checkpoint cleanly (proves the indexer is built only on computing layers 1,2,3,7) → cross-layer fwd/bwd (skip layers reuse anchor top-k, no holder assert) → save LoRA adapter → TRAIN EXIT 0.
  • GLM-5.1 6-layer full e2e (rollout → train → save): TRAIN EXIT 0 — the feature-gate leaves the GLM-5.1 path unchanged (regression check).

Notes

  • No megatron-core source change; everything is on the Bridge side and self-disables on a newer megatron-core.
  • sglang does not yet serve the GLM-5.2 cross-layer rollout, so the GLM-5.2 validation is training-side (train-only); GLM-5.1 runs the full rollout→train loop.

🤖 Generated with Claude Code

Older megatron-core (e.g. the radixark/miles image's 0.16.0rc0) only wires
"gated_delta_net" in get_experimental_attention_variant_module_spec and raises
ValueError for "dsa", and its get_dsa_module_spec_for_backend omits the metainfo
the variant layer-builder reads. The GLM-5/5.1 bridge sets
experimental_attention_variant="dsa" + transformer_layer_spec to that builder, so
the model fails to build on such a core (LoRA and full-FT bridge paths alike).

Wrap transformer_layer_spec in _build_glm5_dsa_block_spec: PREFER megatron-core's
native handling, and only when it raises for "dsa" back-fill via the shipped DSA
builder + set metainfo["fuse_input_layernorm"]=False (MLA-based DSA keeps a
separate, non-fused input layernorm, like the deepseek_v4 dsv4 spec). On newer
megatron-core (which handles "dsa" natively + sets metainfo) this is a transparent
no-op, so the helper self-disables and can be deleted once the runtime core is bumped.

Same spirit as the other miles-compat backports (mimo.config.role, training.config,
parse_hybrid_pattern). Verified e2e: GLM-5.1 6-layer GRPO LoRA via bridge with no
caller-side patch -> Job succeeded + PEFT adapter saved.

Signed-off-by: Yusheng Su <yushengsu.thu@gmail.com>
…dim)

GLM-5.2 sets rope_theta=8e6, and under transformers>=5.12 the parsed
GlmMoeDsaConfig reports qk_rope_head_dim as head_dim (192) instead of the
config.json value (64). The base config-mapping then sized MLA
linear_kv_down_proj as kv_lora_rank + 192 = 704, contradicting the checkpoint
(kv_a_proj_with_mqa = kv_lora_rank + qk_rope_head_dim = 576 = 512 + 64).

- rotary_base: read rope_theta whether nested in rope_parameters or flat.
- qk_pos_emb_head_dim: re-read qk_rope_head_dim straight from config.json so
  MLA rope/kv dims match the weights. No-op when the parse is already correct;
  GLM-5.1 is unaffected (its head_dim already equals qk_rope_head_dim = 64).

Signed-off-by: Yusheng Su <yushengsu.thu@gmail.com>
GLM-5.2 keeps GLM-5.1's glm_moe_dsa arch but only "computing"/anchor layers
carry the lightning indexer and compute the sparse top-k; "skip" layers reuse
the most recent computing layer's top-k (HF config index_topk_freq>1 +
index_skip_topk_offset). megatron-core's DSA is per-layer only, so this adds a
Bridge-owned CrossLayerDSAttention(DSAttention): anchors publish topk_indices to
a per-microbatch holder (packed_seq_params for thd, thread-local for bshd), skip
layers drop their indexer (matching the subset checkpoint) and reuse the source
anchor's top-k. get_glm5_crosslayer_dsa_spec calls megatron-core's exact
get_dsa_module_spec_for_backend and only swaps core_attention.module.

Feature-gated in _build_glm5_dsa_block_spec on dsa_index_topk_freq>1, so GLM-5.1
(no freq -> 1) keeps the existing per-layer path unchanged. No megatron-core
edits. Validated: GLM-5.2 7-layer train-only e2e (build + 98G subset-ckpt load +
cross-layer fwd/bwd + LoRA adapter) and GLM-5.1 6-layer full e2e regression both
reach TRAIN EXIT 0.

Signed-off-by: Yusheng Su <yushengsu.thu@gmail.com>
@yushengsu-thu yushengsu-thu changed the title backport: feature-detect DSA experimental-attention spec for GLM-5/5.1 backport: GLM-5.1/5.2 glm_moe_dsa bridge — DSA spec feature-detect + GLM-5.2 cross-layer index sharing Jun 20, 2026
@yushengsu-thu yushengsu-thu changed the title backport: GLM-5.1/5.2 glm_moe_dsa bridge — DSA spec feature-detect + GLM-5.2 cross-layer index sharing backport: feature-detect DSA experimental-attention spec for GLM-5/5.1/5.2 Jun 20, 2026
… layer (build time)

The per-microbatch top-k holder used for DSA cross-layer index sharing does NOT cross
pipeline boundaries, so a skip layer's source computing layer must live in the same PP
stage. Previously a bad (virtual) pipeline split that started a stage on a skip layer was
only caught by the runtime guard in CrossLayerDSAttention.forward (first forward of that
layer). This adds a build-time check mirroring slime's get_glm5_spec: it fails at model
construction with a precise message ("stage starts at global layer_number=X which is a skip
layer whose source computing layer=Y is on a previous stage").

- cross_layer_dsa.py: new assert_pp_stage_starts_on_computing_layer(config, vp_stage); uses
  get_transformer_layer_offset + is_skip_topk_layer. No-op unless dsa_index_topk_freq>1, and
  silently returns if the layout can't be determined (runtime guard remains the backstop).
- glm5_bridge.py: _build_glm5_dsa_block_spec calls it (gated on dsa + freq>1) before building
  the block, forwarding vp_stage.

GLM-5.1 (freq=1) and valid PP=1 layouts are unaffected (no-op). Verified: raises on a stage
starting at a skip layer, no-op on a computing-layer start / GLM-5.1, and the real PP=1
GLM-5.2 build still passes.

Signed-off-by: Yusheng Su <yushengsu.thu@gmail.com>
…older not recompute-safe)

The per-microbatch top-k holder used for DSA cross-layer index sharing rides on
packed_seq_params in the thd layout (closure-captured by the activation-checkpoint
custom_forward, so it survives recompute). In the bshd layout packed_seq_params is None and the
holder falls back to a process thread-local dict, which is NOT recompute-safe: under activation
recompute a skip layer's recompute can read a stale anchor top-k (the dict is not captured per
microbatch), silently corrupting gradients.

CrossLayerDSAttention now records whether activation recompute is configured
(self._recompute_active, from config.recompute_granularity) and, on a cross-layer forward with
packed_seq_params is None, raises a clear AssertionError directing the user to --qkv-format thd
(recompute-safe) or to disable activation recompute. No-op for thd, for no-recompute, and for
GLM-5.1 (index_topk_freq=1). Training currently uses thd, so this only guards the unsafe
bshd + recompute + cross-layer combination.

Signed-off-by: Yusheng Su <yushengsu.thu@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant